{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import sys\n",
    "import yaml\n",
    "from torchvision import transforms, datasets\n",
    "import torchvision\n",
    "import numpy as np\n",
    "import os\n",
    "from sklearn import preprocessing\n",
    "from torch.utils.data.dataloader import DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "sys.path.append('../')\n",
    "from models.resnet_base_network import ResNet18"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 512\n",
    "data_transforms = torchvision.transforms.Compose([transforms.ToTensor()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = yaml.load(open(\"../config/config.yaml\", \"r\"), Loader=yaml.FullLoader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = datasets.STL10('/home/thalles/Downloads/', split='train', download=False,\n",
    "                               transform=data_transforms)\n",
    "\n",
    "test_dataset = datasets.STL10('/home/thalles/Downloads/', split='test', download=False,\n",
    "                               transform=data_transforms)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input shape: torch.Size([3, 96, 96])\n"
     ]
    }
   ],
   "source": [
    "print(\"Input shape:\", train_dataset[0][0].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader = DataLoader(train_dataset, batch_size=batch_size,\n",
    "                          num_workers=0, drop_last=False, shuffle=True)\n",
    "\n",
    "test_loader = DataLoader(test_dataset, batch_size=batch_size,\n",
    "                          num_workers=0, drop_last=False, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = 'cpu' #'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "encoder = ResNet18(**config['network'])\n",
    "output_feature_dim = encoder.projetion.net[0].in_features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parameters successfully loaded.\n"
     ]
    }
   ],
   "source": [
    "#load pre-trained parameters\n",
    "load_params = torch.load(os.path.join('/home/thalles/PycharmProjects/PyTorch-BYOL/runs/resnet-18_80-epochs/checkpoints/model.pth'),\n",
    "                         map_location=torch.device(torch.device(device)))\n",
    "\n",
    "if 'online_network_state_dict' in load_params:\n",
    "    encoder.load_state_dict(load_params['online_network_state_dict'])\n",
    "    print(\"Parameters successfully loaded.\")\n",
    "\n",
    "# remove the projection head\n",
    "encoder = torch.nn.Sequential(*list(encoder.children())[:-1])    \n",
    "encoder = encoder.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): Sequential(\n",
       "    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
       "    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (2): ReLU(inplace=True)\n",
       "    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
       "    (4): Sequential(\n",
       "      (0): BasicBlock(\n",
       "        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (1): BasicBlock(\n",
       "        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "    )\n",
       "    (5): Sequential(\n",
       "      (0): BasicBlock(\n",
       "        (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (downsample): Sequential(\n",
       "          (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (1): BasicBlock(\n",
       "        (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "    )\n",
       "    (6): Sequential(\n",
       "      (0): BasicBlock(\n",
       "        (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (downsample): Sequential(\n",
       "          (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (1): BasicBlock(\n",
       "        (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "    )\n",
       "    (7): Sequential(\n",
       "      (0): BasicBlock(\n",
       "        (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (downsample): Sequential(\n",
       "          (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "          (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (1): BasicBlock(\n",
       "        (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "    )\n",
       "    (8): AdaptiveAvgPool2d(output_size=(1, 1))\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "encoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LogisticRegression(torch.nn.Module):\n",
    "    def __init__(self, input_dim, output_dim):\n",
    "        super(LogisticRegression, self).__init__()\n",
    "        self.linear = torch.nn.Linear(input_dim, output_dim)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        return self.linear(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "logreg = LogisticRegression(output_feature_dim, 10)\n",
    "logreg = logreg.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_features_from_encoder(encoder, loader):\n",
    "    \n",
    "    x_train = []\n",
    "    y_train = []\n",
    "\n",
    "    # get the features from the pre-trained model\n",
    "    for i, (x, y) in enumerate(loader):\n",
    "        with torch.no_grad():\n",
    "            feature_vector = encoder(x)\n",
    "            x_train.extend(feature_vector)\n",
    "            y_train.extend(y.numpy())\n",
    "\n",
    "            \n",
    "    x_train = torch.stack(x_train)\n",
    "    y_train = torch.tensor(y_train)\n",
    "    return x_train, y_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training data shape: torch.Size([5000, 512]) torch.Size([5000])\n",
      "Testing data shape: torch.Size([8000, 512]) torch.Size([8000])\n"
     ]
    }
   ],
   "source": [
    "encoder.eval()\n",
    "x_train, y_train = get_features_from_encoder(encoder, train_loader)\n",
    "x_test, y_test = get_features_from_encoder(encoder, test_loader)\n",
    "\n",
    "if len(x_train.shape) > 2:\n",
    "    x_train = torch.mean(x_train, dim=[2, 3])\n",
    "    x_test = torch.mean(x_test, dim=[2, 3])\n",
    "    \n",
    "print(\"Training data shape:\", x_train.shape, y_train.shape)\n",
    "print(\"Testing data shape:\", x_test.shape, y_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_data_loaders_from_arrays(X_train, y_train, X_test, y_test):\n",
    "\n",
    "    train = torch.utils.data.TensorDataset(X_train, y_train)\n",
    "    train_loader = torch.utils.data.DataLoader(train, batch_size=64, shuffle=True)\n",
    "\n",
    "    test = torch.utils.data.TensorDataset(X_test, y_test)\n",
    "    test_loader = torch.utils.data.DataLoader(test, batch_size=512, shuffle=False)\n",
    "    return train_loader, test_loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "scaler = preprocessing.StandardScaler()\n",
    "scaler.fit(x_train)\n",
    "x_train = scaler.transform(x_train).astype(np.float32)\n",
    "x_test = scaler.transform(x_test).astype(np.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader, test_loader = create_data_loaders_from_arrays(torch.from_numpy(x_train), y_train, torch.from_numpy(x_test), y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing accuracy: 59.0875\n",
      "Testing accuracy: 67.525\n",
      "Testing accuracy: 68.4875\n",
      "Testing accuracy: 68.9125\n",
      "Testing accuracy: 69.325\n",
      "Testing accuracy: 69.3875\n",
      "Testing accuracy: 69.7125\n",
      "Testing accuracy: 69.5875\n",
      "Testing accuracy: 69.9375\n",
      "Testing accuracy: 69.6625\n",
      "Testing accuracy: 70.0125\n",
      "Testing accuracy: 70.175\n",
      "Testing accuracy: 70.1125\n",
      "Testing accuracy: 69.975\n",
      "Testing accuracy: 70.0\n",
      "Testing accuracy: 70.05\n",
      "Testing accuracy: 69.9125\n",
      "Testing accuracy: 69.85\n",
      "Testing accuracy: 70.05\n",
      "Testing accuracy: 69.6875\n"
     ]
    }
   ],
   "source": [
    "optimizer = torch.optim.Adam(logreg.parameters(), lr=3e-4)\n",
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "eval_every_n_epochs = 10\n",
    "\n",
    "for epoch in range(200):\n",
    "#     train_acc = []\n",
    "    for x, y in train_loader:\n",
    "\n",
    "        x = x.to(device)\n",
    "        y = y.to(device)\n",
    "        \n",
    "        # zero the parameter gradients\n",
    "        optimizer.zero_grad()        \n",
    "        \n",
    "        logits = logreg(x)\n",
    "        predictions = torch.argmax(logits, dim=1)\n",
    "        \n",
    "        loss = criterion(logits, y)\n",
    "        \n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "    \n",
    "    total = 0\n",
    "    if epoch % eval_every_n_epochs == 0:\n",
    "        correct = 0\n",
    "        for x, y in test_loader:\n",
    "            x = x.to(device)\n",
    "            y = y.to(device)\n",
    "\n",
    "            logits = logreg(x)\n",
    "            predictions = torch.argmax(logits, dim=1)\n",
    "            \n",
    "            total += y.size(0)\n",
    "            correct += (predictions == y).sum().item()\n",
    "            \n",
    "        acc = 100 * correct / total\n",
    "        print(f\"Testing accuracy: {np.mean(acc)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "pytorch"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
